{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "79a14c63-a085-493f-8db9-6af3e1d744b5",
   "metadata": {},
   "source": [
    "### `MultiNodeWeightedSampler` example\n",
    "In this notebook, we will explore the usage of `MultiNodeWeightedSampler` in `torchdata.nodes`.\n",
    "\n",
    "`MultiNodeWeightedSampler` allows us to sample with a probability from multiple datsets. We will make three datasets, and then see how does the composition of the output depends on the weights defined in the `MultiNodeWeightedSampler`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0b283748-9b3f-4b9e-bbc5-db0791f4d900",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchdata.nodes import Mapper, MultiNodeWeightedSampler, IterableWrapper, Loader\n",
    "import collections\n",
    "\n",
    "# defining a simple map_fn as a place holder example\n",
    "def map_fn(item):\n",
    "    return {\"x\":item}\n",
    "\n",
    "\n",
    "def constant_stream(value: int):\n",
    "  while True:\n",
    "    yield value\n",
    "\n",
    "# First, we create a dictionary of three datasets, with each dataset converted into BaseNode using the IterableWrapper\n",
    "num_datasets = 3\n",
    "datasets = {\n",
    "    \"ds0\": IterableWrapper(constant_stream(0)),\n",
    "    \"ds1\": IterableWrapper(constant_stream(1)),\n",
    "    \"ds2\": IterableWrapper(constant_stream(2)),\n",
    "}\n",
    "\n",
    "# Next, we have to define weights for sampling from a particular dataset\n",
    "weights = {\"ds0\": 0.5, \"ds1\": 0.25, \"ds2\": 0.25}\n",
    "\n",
    "# Finally we instatiate the MultiNodeWeightedSampler to sample from our datasets\n",
    "multi_node_sampler = MultiNodeWeightedSampler(datasets, weights)\n",
    "\n",
    "# Since nodes are iterators, they need to be manually .reset() between epochs.\n",
    "# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n",
    "loader = Loader(multi_node_sampler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "77784ba3-b917-4083-aed4-dba2374110d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fractions = {0: 0.49791, 2: 0.25067, 1: 0.25142}\n",
      "The original weights were = {'ds0': 0.5, 'ds1': 0.25, 'ds2': 0.25}\n"
     ]
    }
   ],
   "source": [
    "# Let's take a look at the output for 100k numbers, compute the fraction of each dataset in that batch\n",
    "# and compare the batch composition with our given weights\n",
    "n = 100000\n",
    "it = iter(loader)\n",
    "samples = [next(it) for _ in range(n)]\n",
    "fractions = {k: v/len(samples) for k, v in collections.Counter(samples).items()}\n",
    "print(f\"fractions = {fractions}\")\n",
    "print(f\"The original weights were = {weights}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
